{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "906e07f6e562"
   },
   "outputs": [],
   "source": [
    "# @title Copyright 2022 The Cirq Developers\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8bbd73c03ac2"
   },
   "source": [
    "# Custom Transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "25eb74f260d6"
   },
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://quantumai.google/cirq/transform/custom_transformers\"><img src=\"https://quantumai.google/site-assets/images/buttons/quantumai_logo_1x.png\" />View on QuantumAI</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/quantumlib/Cirq/blob/main/docs/transform/custom_transformers.ipynb\"><img src=\"https://quantumai.google/site-assets/images/buttons/colab_logo_1x.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/quantumlib/Cirq/blob/main/docs/transform/custom_transformers.ipynb\"><img src=\"https://quantumai.google/site-assets/images/buttons/github_logo_1x.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/Cirq/docs/transform/custom_transformers.ipynb\"><img src=\"https://quantumai.google/site-assets/images/buttons/download_icon_1x.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6d36e471ec63"
   },
   "source": [
    "The [Transformers](./transformers.ipynb) page introduced what a transformer is, what transformers are available in Cirq, and how to create a simple one as a composite of others. This page covers the details necessary for creating more nuanced custom transformers, including `cirq.TransformerContext`, primitives and decompositions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8d4ac441b1fc"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "846b32703c5c"
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    import cirq\n",
    "except ImportError:\n",
    "    print(\"installing cirq...\")\n",
    "    !pip install --quiet cirq\n",
    "    import cirq\n",
    "\n",
    "    print(\"installed cirq.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6da0Ra7bz2St"
   },
   "source": [
    "## `cirq.TRANSFORMER` API and `@cirq.transformer` decorator\n",
    "\n",
    "Any callable that satisfies the `cirq.TRANSFORMER` contract, i.e. takes a `cirq.AbstractCircuit` and `cirq.TransformerContext` and returns a transformed `cirq.AbstractCircuit`, is a valid transformer in Cirq. \n",
    "\n",
    "You can create a custom transformer by simply decorating a class/method, that satisfies the above contract, with `@cirq.transformer` decorator. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "e2893a817870"
   },
   "outputs": [],
   "source": [
    "@cirq.transformer\n",
    "def reverse_circuit(circuit, *, context=None):\n",
    "    \"\"\"Transformer to reverse the input circuit.\"\"\"\n",
    "    return circuit[::-1]\n",
    "\n",
    "\n",
    "@cirq.transformer\n",
    "class SubstituteGate:\n",
    "    \"\"\"Transformer to substitute `source` gates with `target` in the input circuit.\"\"\"\n",
    "\n",
    "    def __init__(self, source, target):\n",
    "        self._source = source\n",
    "        self._target = target\n",
    "\n",
    "    def __call__(self, circuit, *, context=None):\n",
    "        batch_replace = []\n",
    "        for i, op in circuit.findall_operations(lambda op: op.gate == self._source):\n",
    "            batch_replace.append((i, op, self._target.on(*op.qubits)))\n",
    "        transformed_circuit = circuit.unfreeze(copy=True)\n",
    "        transformed_circuit.batch_replace(batch_replace)\n",
    "        return transformed_circuit\n",
    "\n",
    "\n",
    "# Build your circuit\n",
    "q = cirq.NamedQubit(\"q\")\n",
    "circuit = cirq.Circuit(\n",
    "    cirq.X(q), cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q), cirq.Y(q))), cirq.Z(q)\n",
    ")\n",
    "# Transform and compare the circuits.\n",
    "substitute_gate = SubstituteGate(cirq.X, cirq.S)\n",
    "print(\"Original Circuit:\", circuit, \"\\n\", sep=\"\\n\")\n",
    "print(\"Reversed Circuit:\", reverse_circuit(circuit), \"\\n\", sep=\"\\n\")\n",
    "print(\"Substituted Circuit:\", substitute_gate(circuit), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OLJrXdyTCGev"
   },
   "source": [
    "## `cirq.TransformerContext` to store common configurable options\n",
    "`cirq.TransformerContext` is a dataclass that stores common configurable options for all transformers. All cirq transformers should accept the transformer context as an optional keyword argument. \n",
    "\n",
    "The `@cirq.transformer` decorator can inspect the `cirq.TransformerContext` argument and automatically append useful functionality, like support for automated logging and recursively running the transformer on nested sub-circuits.\n"
   ]
  },
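  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of what the context holds, the cell below builds a `cirq.TransformerContext`, reads back its `deep` and `tags_to_ignore` fields, and then derives a modified copy with the standard-library `dataclasses.replace`. The variable names `base_context` and `deep_context` are illustrative only, and the sketch assumes that the context is a frozen dataclass whose defaults are `deep=False` and an empty `tags_to_ignore`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "\n",
    "import cirq\n",
    "\n",
    "# A context that only sets tags_to_ignore; the other options keep their defaults.\n",
    "base_context = cirq.TransformerContext(tags_to_ignore=(\"ignore\",))\n",
    "print(\"deep:\", base_context.deep)\n",
    "print(\"tags_to_ignore:\", base_context.tags_to_ignore)\n",
    "\n",
    "# Derive a copy with deep=True instead of mutating the original context.\n",
    "deep_context = dataclasses.replace(base_context, deep=True)\n",
    "print(\"deep (copy):\", deep_context.deep)"
   ]
  },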
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5OuMLlNAFPcW"
   },
   "source": [
    "### `cirq.TransformerLogger` and support for automated logging\n",
    "The `cirq.TransformerLogger` class is used to log the actions of a transformer on an input circuit. `@cirq.transformer` decorator automatically adds support for logging the initial and final circuits for each transformer step. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1-1r6Ha9F1c3"
   },
   "outputs": [],
   "source": [
    "# Note that you want to log the steps.\n",
    "context = cirq.TransformerContext(logger=cirq.TransformerLogger())\n",
    "# Transform the circuit.\n",
    "transformed_circuit = reverse_circuit(circuit, context=context)\n",
    "transformed_circuit = substitute_gate(transformed_circuit, context=context)\n",
    "# Show the steps.\n",
    "context.logger.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6cdf0a2644a5"
   },
   "source": [
    "Neither of the custom transformers, `reverse_circuit` or `substitute_gate`, had any explicit support for a logger present in the `context` argument, but the decorator was able to use it anyways. \n",
    "\n",
    "If your custom transformer calls another transformer as part of it, then that transformer should log its behavior as long as you pass the `context` object to it. All Cirq-provided transformers do this. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b0f12bdea3ba"
   },
   "outputs": [],
   "source": [
    "@cirq.transformer\n",
    "def reverse_and_substitute(circuit, context=None):\n",
    "    reversed_circuit = reverse_circuit(circuit, context=context)\n",
    "    reversed_and_substituted_circuit = substitute_gate(reversed_circuit, context=context)\n",
    "    return reversed_and_substituted_circuit\n",
    "\n",
    "\n",
    "# Note that you want to log the steps.\n",
    "context = cirq.TransformerContext(logger=cirq.TransformerLogger())\n",
    "# Transform the circuit.\n",
    "transformed_circuit = reverse_and_substitute(circuit, context=context)\n",
    "# Show the steps.\n",
    "context.logger.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4bc4a9b5258f"
   },
   "source": [
    "Note: Each transformer that is run on the circuit is indexed globally, over all transformers run through the logger. However, when one transformer (`reverse_and_substitute`) runs other transformers as a subprocess (`reverse_circuit` and `SubstituteGate`), the constitutent transformers are indented to show this relationship. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UvFdvxQ6MZD9"
   },
   "source": [
    "### Support for `deep=True`\n",
    "You can call `@cirq.transformer(add_deep_support=True)` to automatically add the functionality of recursively running the custom transformer on circuits wrapped inside `cirq.CircuitOperation`. The recursive execution behavior of the transformer can then be controlled by setting `deep=True` in the transformer context. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BobP3UFBJhP1"
   },
   "outputs": [],
   "source": [
    "@cirq.transformer(add_deep_support=True)\n",
    "def reverse_circuit_deep(circuit, *, context=None):\n",
    "    \"\"\"Transformer to reverse the input circuit.\"\"\"\n",
    "    return circuit[::-1]\n",
    "\n",
    "\n",
    "@cirq.transformer(add_deep_support=True)\n",
    "class SubstituteGateDeep(SubstituteGate):\n",
    "    \"\"\"Transformer to substitute `source` gates with `target` in the input circuit.\"\"\"\n",
    "\n",
    "    pass\n",
    "\n",
    "\n",
    "# Note that you want to transform the CircuitOperations.\n",
    "context = cirq.TransformerContext(deep=True)\n",
    "# Transform and compare the circuits.\n",
    "substitute_gate_deep = SubstituteGateDeep(cirq.X, cirq.S)\n",
    "print(\"Original Circuit:\", circuit, \"\\n\", sep=\"\\n\")\n",
    "print(\n",
    "    \"Reversed Circuit with deep=True:\",\n",
    "    reverse_circuit_deep(circuit, context=context),\n",
    "    \"\\n\",\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "print(\n",
    "    \"Substituted Circuit with deep=True:\", substitute_gate_deep(circuit, context=context), sep=\"\\n\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bDizgWMmPLq0"
   },
   "source": [
    "## Transformer Primitives and Decompositions\n",
    "\n",
    "If you need to perform more fundamental changes than just running other transformers in sequence (like `SubstituteGate` did with `cirq.Circuit.batch_replace`), Cirq provides circuit compilation primitives and gate decomposition utilities for doing so. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sXbcG1VdPRkm"
   },
   "source": [
    "### Moment preserving transformer primitives\n",
    "\n",
    "Cirq's transformer primitives are useful abstractions to implement common transformer patterns, while preserving the moment structure of input circuit. Some of the notable transformer primitives are:\n",
    "\n",
    "- **`cirq.map_operations`**: Applies local transformations on operations, by calling `map_func(op)` for each `op`.\n",
    "- **`cirq.map_moments`**: Applies local transformation on moments, by calling `map_func(m)` for each moment `m`.\n",
    "- **`cirq.merge_operations`**: Merges connected component of operations by iteratively calling `merge_func(op1, op2)`  for every pair of mergeable operations `op1` and `op2`.\n",
    "- **`cirq.merge_moments`**: Merges adjacent moments, from left to right, by iteratively calling `merge_func(m1, m2)` for adjacent moments `m1` and `m2`.\n",
    "\n",
    "An important property of these primitives is that they have support for common configurable options present in `cirq.TransformerContext`, such as `tags_to_ignore` and `deep`, as demonstrated in the example below. \n",
    "\n",
    "Note: Primitives support both the `deep` argument and the `tags_to_ignore` argument, like many transformers, so you can easily pass those values in from a `TransformContext`, if one is available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BCXc94owfIJh"
   },
   "outputs": [],
   "source": [
    "@cirq.transformer\n",
    "def substitute_gate_using_primitives(circuit, *, context=None, source=cirq.X, target=cirq.S):\n",
    "    \"\"\"Transformer to substitute `source` gates with `target` in the input circuit.\n",
    "\n",
    "    The transformer is implemented using `cirq.map_operations` primitive and hence\n",
    "    has built-in support for\n",
    "      1. Recursively running the transformer on sub-circuits if `context.deep is True`.\n",
    "      2. Ignoring operations tagged with any of `context.tags_to_ignore`.\n",
    "    \"\"\"\n",
    "    return cirq.map_operations(\n",
    "        circuit,\n",
    "        map_func=lambda op, _: target.on(*op.qubits) if op.gate == source else op,\n",
    "        deep=context.deep if context else False,\n",
    "        tags_to_ignore=context.tags_to_ignore if context else (),\n",
    "    )\n",
    "\n",
    "\n",
    "# Build your circuit from x_y_x components.\n",
    "x_y_x = [cirq.X(q), cirq.Y(q), cirq.X(q).with_tags(\"ignore\")]\n",
    "circuit = cirq.Circuit(x_y_x, cirq.CircuitOperation(cirq.FrozenCircuit(x_y_x)), x_y_x)\n",
    "# Note that you want to transform the CircuitOperations and ignore tagged operations.\n",
    "context = cirq.TransformerContext(deep=True, tags_to_ignore=(\"ignore\",))\n",
    "# Compare the before and after circuits.\n",
    "print(\"Original Circuit:\", circuit, \"\\n\", sep=\"\\n\")\n",
    "print(\n",
    "    \"Substituted Circuit:\",\n",
    "    substitute_gate_using_primitives(circuit, context=context),\n",
    "    \"\\n\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
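  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The other primitives follow the same pattern as `cirq.map_operations` above. As a hedged sketch, the cell below uses `cirq.merge_moments` to combine adjacent moments that act on disjoint qubits. The helper `merge_if_disjoint`, the qubits `a` and `b`, and the circuit names are hypothetical and introduced only for this illustration. Returning `None` from the merge function tells the primitive to leave that pair of moments unmerged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cirq\n",
    "\n",
    "\n",
    "def merge_if_disjoint(m1, m2):\n",
    "    \"\"\"Merges two adjacent moments when they share no qubits.\"\"\"\n",
    "    if m1.qubits.isdisjoint(m2.qubits):\n",
    "        return cirq.Moment(m1.operations + m2.operations)\n",
    "    # Returning None keeps the two moments separate.\n",
    "    return None\n",
    "\n",
    "\n",
    "a, b = cirq.LineQubit.range(2)\n",
    "sparse_circuit = cirq.Circuit(\n",
    "    cirq.Moment(cirq.X(a)), cirq.Moment(cirq.Y(b)), cirq.Moment(cirq.CNOT(a, b))\n",
    ")\n",
    "packed_circuit = cirq.merge_moments(sparse_circuit, merge_if_disjoint)\n",
    "print(\"Before:\", sparse_circuit, \"\\n\", sep=\"\\n\")\n",
    "print(\"After:\", packed_circuit, sep=\"\\n\")"
   ]
  },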
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NzTPE_wGjWkJ"
   },
   "source": [
    "## Analytical Gate Decompositions\n",
    "\n",
    "Gate decomposition is the process of implementing / decomposing a given unitary `U` using only gates that belong to a specific target gateset. \n",
    "\n",
    "Cirq provides analytical decomposition methods, often based on [KAK Decomposition](https://arxiv.org/abs/quant-ph/0507171), to decompose one-, two-, and three-qubit unitary matrices into specific target gatesets. Some notable decompositions are:\n",
    "\n",
    "* **`cirq.single_qubit_matrix_to_pauli_rotations`**: Decomposes a single qubit matrix to ZPow/XPow/YPow rotations. \n",
    "* **`cirq.single_qubit_matrix_to_phased_x_z`**: Decomposes a single-qubit matrix to a PhasedX and Z gate.\n",
    "* **`cirq.two_qubit_matrix_to_sqrt_iswap_operations`**: Decomposes any two-qubit unitary matrix into ZPow/XPow/YPow/sqrt-iSWAP gates.\n",
    "* **`cirq.two_qubit_matrix_to_cz_operations`**: Decomposes any two-qubit unitary matrix into ZPow/XPow/YPow/CZ gates.\n",
    "* **`cirq.three_qubit_matrix_to_operations`**: Decomposes any three-qubit unitary matrix into CZ/CNOT and single qubit rotations.\n",
    "\n",
    "\n",
    "\n",
    "You can use these analytical decomposition methods to build transformers which can rewrite a given circuit using only gates from the target gateset. This example again uses the transformer primitives to support recursive execution and `ignore` tagging."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Pb7sXwEFcIxB"
   },
   "outputs": [],
   "source": [
    "@cirq.transformer\n",
    "def convert_to_cz_target(circuit, *, context=None, atol=1e-8, allow_partial_czs=True):\n",
    "    \"\"\"Transformer to rewrite the given circuit using CZs + 1-qubit rotations.\n",
    "\n",
    "    Note that the transformer decomposes only operations on <= 2-qubits and is\n",
    "    presented as an illustration of using transformer primitives + analytical\n",
    "    decomposition methods.\n",
    "    \"\"\"\n",
    "\n",
    "    def map_func(op: cirq.Operation, _) -> cirq.OP_TREE:\n",
    "        if not (cirq.has_unitary(op) and cirq.num_qubits(op) <= 2):\n",
    "            return op\n",
    "        matrix = cirq.unitary(op)\n",
    "        qubits = op.qubits\n",
    "        if cirq.num_qubits(op) == 1:\n",
    "            g = cirq.single_qubit_matrix_to_phxz(matrix)\n",
    "            return g.on(*qubits) if g else []\n",
    "        return cirq.two_qubit_matrix_to_cz_operations(\n",
    "            *qubits, matrix, allow_partial_czs=allow_partial_czs, atol=atol\n",
    "        )\n",
    "\n",
    "    return cirq.map_operations_and_unroll(\n",
    "        circuit,\n",
    "        map_func,\n",
    "        deep=context.deep if context else False,\n",
    "        tags_to_ignore=context.tags_to_ignore if context else (),\n",
    "    )\n",
    "\n",
    "\n",
    "# Build the circuit from three versions of the same random component\n",
    "component = cirq.testing.random_circuit(qubits=3, n_moments=2, op_density=0.8, random_state=1234)\n",
    "component_operation = cirq.CircuitOperation(cirq.FrozenCircuit(component))\n",
    "# A normal component, a CircuitOperation version, and a ignore-tagged CircuitOperation version\n",
    "circuit = cirq.Circuit(component, component_operation, component_operation.with_tags('ignore'))\n",
    "# Note that you want to transform the CircuitOperations, ignore tagged operations, and log the steps.\n",
    "context = cirq.TransformerContext(\n",
    "    deep=True, tags_to_ignore=(\"ignore\",), logger=cirq.TransformerLogger()\n",
    ")\n",
    "# Run your transformer.\n",
    "converted_circuit = convert_to_cz_target(circuit, context=context)\n",
    "# Ensure that the resulting circuit is equivalent.\n",
    "cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuit, converted_circuit)\n",
    "# Show the steps executed.\n",
    "context.logger.show()"
   ]
  },
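  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decomposition functions listed above can also be called directly on a matrix. The cell below is a small sketch that decomposes the unitary of `cirq.SWAP` into CZ gates and single-qubit rotations with `cirq.two_qubit_matrix_to_cz_operations`, then compares the result against the original unitary up to global phase. The names `q0`, `q1`, `swap_matrix`, and `decomposed` are illustrative, and `allow_partial_czs=False` is chosen here only to request full CZ gates rather than partial ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cirq\n",
    "\n",
    "q0, q1 = cirq.LineQubit.range(2)\n",
    "swap_matrix = cirq.unitary(cirq.SWAP)\n",
    "decomposed = cirq.Circuit(\n",
    "    cirq.two_qubit_matrix_to_cz_operations(q0, q1, swap_matrix, allow_partial_czs=False)\n",
    ")\n",
    "print(decomposed)\n",
    "\n",
    "# The decomposition should match the target unitary up to a global phase.\n",
    "matches = cirq.allclose_up_to_global_phase(\n",
    "    decomposed.unitary(qubit_order=[q0, q1]), swap_matrix, atol=1e-6\n",
    ")\n",
    "print(\"Equivalent up to global phase:\", matches)"
   ]
  },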
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jE7XXZs7bmTJ"
   },
   "source": [
    "## Heuristic Gate Decompositions\n",
    "Cirq also provides heuristic methods for decomposing any two qubit unitary matrix in terms of any specified two qubit target unitary + single qubit rotations. These methods are useful when accurate analytical decompositions for the target unitary are not known or when gate decomposition fidelity (i.e. accuracy of decomposition) can be traded off against decomposition depth (i.e. number of 2q gates in resulting decomposition) to achieve a higher overall gate fidelity. \n",
    "\n",
    "\n",
    "See the following resources for more details on heuristic gate decomposition:\n",
    "\n",
    "* **`cirq.two_qubit_gate_product_tabulation`**\n",
    "* **https://arxiv.org/pdf/2106.15490.pdf**"
   ]
  },
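  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the heuristic approach, the cell below tabulates products of the sqrt-iSWAP gate with `cirq.two_qubit_gate_product_tabulation` and then asks the tabulation to compile a CZ unitary. The `max_infidelity` value of `0.05`, the seed, and the variable names are assumptions chosen to keep the example quick to run, not recommended settings. A lookup may miss, so the result's `success` flag is checked before reading the decomposition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cirq\n",
    "\n",
    "# Tabulate products of sqrt-iSWAP; a looser max_infidelity keeps the table small.\n",
    "sqrt_iswap_tabulation = cirq.two_qubit_gate_product_tabulation(\n",
    "    cirq.unitary(cirq.SQRT_ISWAP), 0.05, random_state=1234\n",
    ")\n",
    "# Look up a decomposition of CZ in terms of the base gate and single-qubit unitaries.\n",
    "tabulation_result = sqrt_iswap_tabulation.compile_two_qubit_gate(cirq.unitary(cirq.CZ))\n",
    "print(\"Decomposition found within tolerance:\", tabulation_result.success)\n",
    "if tabulation_result.success:\n",
    "    print(\"Local unitary layers:\", len(tabulation_result.local_unitaries))"
   ]
  },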
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aTNKFcPesNpy"
   },
   "source": [
    "# Summary\n",
    "Cirq provides a flexible and powerful framework to\n",
    "* Use built-in transformer primitives and analytical tools to create powerful custom transformers both from scratch and by composing existing transformers.\n",
    "* Easily integrate custom transformers with built-in infrastructure to augment functionality like automated logging, recursive execution on sub-circuits, support for no-compile tags etc."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "custom_transformers.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
